{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "import jax\n",
    "import jax.numpy as jnp\n",
    "\n",
    "import os, sys\n",
    "\n",
    "# Put the directory three levels above the working directory on the import path\n",
    "# (presumably where the shared `training` module lives -- confirm).\n",
    "# Guarded so that re-running this cell does not append duplicate entries.\n",
    "project_root = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(\"./\"))))\n",
    "if project_root not in sys.path:\n",
    "    sys.path.append(project_root)\n",
    "\n",
    "from training import Agent\n",
    "from loss import params, loss_fn, evaluate_fn\n",
    "from data import generate_dataset, generate_batch_fn\n",
    "\n",
    "import config\n",
    "\n",
    "# Override the experiment settings on the shared `config` module.\n",
    "# NOTE(review): this runs after `loss` and `data` were imported; if those modules read\n",
    "# these attributes at import time they will not see the overrides -- confirm.\n",
    "n_data = {\n",
    "    \"i\": 100,\n",
    "    \"b\": 100,\n",
    "    \"cx\": 101,\n",
    "    \"ct\": 101,\n",
    "    \"dx\": 20,\n",
    "    \"dt\": 20,\n",
    "}\n",
    "config.n_data = n_data\n",
    "\n",
    "# Batch sizes are derived from the point counts above.\n",
    "n_grid = n_data[\"dx\"] * n_data[\"dt\"]\n",
    "config.batch_size = {\n",
    "    \"dirichlet\": n_data[\"i\"] + 2 * n_data[\"b\"] + n_grid,\n",
    "    \"collocation\": n_grid + n_data[\"cx\"] * n_data[\"ct\"],\n",
    "}\n",
    "\n",
    "config.iterations = 20000\n",
    "config.print_every = 200\n",
    "config.lr = 1e-3\n",
    "# Weights handed to generate_batch_fn below.\n",
    "config.weights = {\n",
    "    \"c1\": 1.0,\n",
    "    \"c2\": 1.0,\n",
    "    \"d1\": 1.0,\n",
    "    \"d2\": 10.0,\n",
    "    \"l1\": 1e-4,\n",
    "    \"l2\": 1e-4,\n",
    "}\n",
    "\n",
    "datasets = generate_dataset(\n",
    "    n_data[\"i\"], n_data[\"b\"], n_data[\"cx\"], n_data[\"ct\"], n_data[\"dx\"], n_data[\"dt\"]\n",
    ")\n",
    "batch_fn, evaluate_batch_fn = generate_batch_fn(config.key, config.batch_size, *datasets, config.weights)\n",
    "\n",
    "# Build the agent, attach the optimizer and run the training loop.\n",
    "agent = Agent(params, loss_fn, evaluate_fn, \"models/{}\".format(config.NAME))\n",
    "agent.compile(config.optimizer, config.lr)\n",
    "agent.train(config.iterations, batch_fn, evaluate_batch_fn, config.print_every, config.save_every, config.loss_names, config.log_file)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from loss import inverse_model, direct_model\n",
    "from data import domain\n",
    "import jax.numpy as jnp\n",
    "import numpy as np\n",
    "from scipy.io import loadmat\n",
    "import matplotlib.pyplot as plt\n",
    "# NOTE(review): the `notebook` backend targets the classic Notebook front end;\n",
    "# on JupyterLab / Notebook 7 `%matplotlib widget` (ipympl) is the usual replacement -- confirm.\n",
    "%matplotlib notebook\n",
    "\n",
    "\n",
    "def a_fn(x):\n",
    "    \"\"\"Reference coefficient a(x) = 1 + (2/pi) * cos(2*pi*x), plotted as the true curve.\"\"\"\n",
    "    return 1+2/np.pi*np.cos(2*np.pi*x)\n",
    "\n",
    "\n",
    "# Evaluation grid: 100 points between the two entries of the first column of `domain`,\n",
    "# shaped as a column vector.\n",
    "x_test = jnp.linspace(*domain[:, 0], 100).reshape((-1, 1))\n",
    "\n",
    "# Inverse model: trained prediction of a(x) against the analytic reference.\n",
    "inverse_params = agent.params[1]\n",
    "a_pred = inverse_model(inverse_params, x_test)\n",
    "a_true = a_fn(x_test)\n",
    "\n",
    "# Direct model: evaluate at the fixed time domain[1, 1] for every x on the grid.\n",
    "direct_params = agent.params[0]\n",
    "t_test = domain[1, 1]*jnp.ones_like(x_test)\n",
    "uv_pred = direct_model(direct_params, jnp.hstack([x_test, t_test]))\n",
    "\n",
    "# Reference u and v: last column of each stored snapshot array.\n",
    "# Plotting below requires these to have as many rows as x_test.\n",
    "data_true = loadmat(\"problem2_2_snapshot_epsilon_1e-12.mat\")\n",
    "u_true, v_true = data_true[\"u_snapshots\"][:, -1], data_true[\"v_snapshots\"][:, -1]\n",
    "\n",
    "# One panel per quantity: (title, prediction, reference).\n",
    "panels = [\n",
    "    (\"a\", a_pred, a_true),\n",
    "    (\"u\", uv_pred[:, 0:1], u_true),\n",
    "    (\"v\", uv_pred[:, 1:2], v_true),\n",
    "]\n",
    "f, ax = plt.subplots(1, 3, figsize = (15, 5))\n",
    "for ax_, (title, pred, ref) in zip(ax, panels):\n",
    "    ax_.plot(x_test, pred, label = \"pred\")\n",
    "    ax_.plot(x_test, ref, label = \"true\")\n",
    "    ax_.set_title(title)\n",
    "    ax_.set_xlabel(\"x\")\n",
    "    ax_.legend()\n",
    "    ax_.grid()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Display the trained parameters stored at index 1 of agent.params\n",
    "# (the entry that is fed to inverse_model).\n",
    "inverse_params = agent.params[1]\n",
    "inverse_params"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
